{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Embeddings API Examples with MLX Server"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to call the embeddings endpoint of MLX Server through its OpenAI-compatible API. It covers generating embeddings for a single text and for batches, comparing texts by cosine similarity, and ranking documents against a search query."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup and Connection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the OpenAI client for API communication\n",
    "from openai import OpenAI\n",
    "\n",
    "# Connect to the local MLX Server through its OpenAI-compatible API\n",
    "client = OpenAI(\n",
    "    base_url=\"http://localhost:8000/v1\",\n",
    "    api_key=\"fake-api-key\",  # placeholder; the OpenAI client needs some key to be set\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basic Embedding Generation\n",
    "\n",
    "### Single Text Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate an embedding for a single text input\n",
    "single_text = \"Artificial intelligence is transforming how we interact with technology.\"\n",
    "response = client.embeddings.create(\n",
    "    input=[single_text],\n",
    "    model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
    ")\n",
    "\n",
    "# The response holds one item per input; each item's embedding is a list of floats\n",
    "single_embedding = response.data[0].embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Batch Processing Multiple Texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed several texts in a single request\n",
    "text_batch = [\n",
    "    \"Machine learning algorithms improve with more data\",\n",
    "    \"Natural language processing helps computers understand human language\",\n",
    "    \"Computer vision allows machines to interpret visual information\"\n",
    "]\n",
    "\n",
    "batch_response = client.embeddings.create(\n",
    "    input=text_batch,\n",
    "    model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of embeddings generated: 3\n",
      "Dimensions of each embedding: 1536\n"
     ]
    }
   ],
   "source": [
    "# Collect one embedding vector per input text\n",
    "embeddings = [item.embedding for item in batch_response.data]\n",
    "print(f\"Number of embeddings generated: {len(embeddings)}\")\n",
    "print(f\"Dimensions of each embedding: {len(embeddings[0])}\")"
   ]
  },
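  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Semantic Similarity Calculation\n",
    "\n",
    "One of the most common uses for embeddings is measuring semantic similarity between texts. Cosine similarity compares the direction of two embedding vectors: a score near 1 means the vectors point the same way, and lower scores indicate less related texts."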
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cosine_similarity_score(vec1, vec2):\n",
    "    \"\"\"Return the cosine similarity between two equal-length vectors.\"\"\"\n",
    "    dot_product = np.dot(vec1, vec2)\n",
    "    norm1 = np.linalg.norm(vec1)\n",
    "    norm2 = np.linalg.norm(vec2)\n",
    "    # Assumes neither vector is all zeros; otherwise the result is nan\n",
    "    return dot_product / (norm1 * norm2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example texts to compare\n",
    "text1 = \"Dogs are loyal pets that provide companionship\"\n",
    "text2 = \"Canines make friendly companions for humans\"\n",
    "text3 = \"Quantum physics explores the behavior of matter at atomic scales\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed all three texts in one request\n",
    "comparison_texts = [text1, text2, text3]\n",
    "comparison_response = client.embeddings.create(\n",
    "    input=comparison_texts,\n",
    "    model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
    ")\n",
    "comparison_embeddings = [item.embedding for item in comparison_response.data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Similarity between text1 and text2: 0.8142\n",
      "Similarity between text1 and text3: 0.6082\n",
      "Similarity between text2 and text3: 0.5739\n"
     ]
    }
   ],
   "source": [
    "# Compare each pair; the two pet-related texts (text1, text2) should score highest\n",
    "similarity_1_2 = cosine_similarity_score(comparison_embeddings[0], comparison_embeddings[1])\n",
    "similarity_1_3 = cosine_similarity_score(comparison_embeddings[0], comparison_embeddings[2])\n",
    "similarity_2_3 = cosine_similarity_score(comparison_embeddings[1], comparison_embeddings[2])\n",
    "\n",
    "print(f\"Similarity between text1 and text2: {similarity_1_2:.4f}\")\n",
    "print(f\"Similarity between text1 and text3: {similarity_1_3:.4f}\")\n",
    "print(f\"Similarity between text2 and text3: {similarity_2_3:.4f}\")"
   ]
  },
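  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Text Search Using Embeddings\n",
    "\n",
    "Embeddings also support a simple semantic search: embed a document collection once, embed each query, and rank the documents by cosine similarity to the query."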
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample document collection\n",
    "documents = [\n",
    "    \"The quick brown fox jumps over the lazy dog\",\n",
    "    \"Machine learning models require training data\",\n",
    "    \"Neural networks are inspired by biological neurons\",\n",
    "    \"Deep learning is a subset of machine learning\",\n",
    "    \"Natural language processing helps with text analysis\",\n",
    "    \"Computer vision systems can detect objects in images\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate embeddings for all documents\n",
    "doc_response = client.embeddings.create(\n",
    "    input=documents,\n",
    "    model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
    ")\n",
    "doc_embeddings = [item.embedding for item in doc_response.data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Search results:\n",
      "Score: 0.8574 - Computer vision systems can detect objects in images\n",
      "Score: 0.8356 - Neural networks are inspired by biological neurons\n",
      "Score: 0.8266 - Natural language processing helps with text analysis\n",
      "Score: 0.8141 - Deep learning is a subset of machine learning\n",
      "Score: 0.7474 - Machine learning models require training data\n",
      "Score: 0.5936 - The quick brown fox jumps over the lazy dog\n"
     ]
    }
   ],
   "source": [
    "def search_documents(query, doc_collection, doc_embeddings):\n",
    "    \"\"\"Rank documents by cosine similarity to the query, most similar first.\"\"\"\n",
    "    # Embed the query with the same model used for the documents\n",
    "    query_response = client.embeddings.create(\n",
    "        input=[query],\n",
    "        model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
    "    )\n",
    "    query_embedding = query_response.data[0].embedding\n",
    "\n",
    "    # Pair each document with its similarity to the query\n",
    "    results = [\n",
    "        (doc, cosine_similarity_score(query_embedding, doc_embedding))\n",
    "        for doc, doc_embedding in zip(doc_collection, doc_embeddings)\n",
    "    ]\n",
    "\n",
    "    # Sort by similarity score (highest first)\n",
    "    return sorted(results, key=lambda x: x[1], reverse=True)\n",
    "\n",
    "# Example search\n",
    "search_results = search_documents(\"How do AI models learn?\", documents, doc_embeddings)\n",
    "\n",
    "print(\"Search results:\")\n",
    "for doc, score in search_results:\n",
    "    print(f\"Score: {score:.4f} - {doc}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
